{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I_2WGU5FlcLQ"
      },
      "source": [
        "# Loading models from \"[Why Do Better Loss Functions Lead to Less Transferable Features?](https://arxiv.org/abs/2010.16402)\"\n",
        "\n",
        "Copyright 2022 Google LLC\n",
        "\n",
        "Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "you may not use this file except in compliance with the License.\n",
        "You may obtain a copy of the License at\n",
        "\n",
        "    https://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "Unless required by applicable law or agreed to in writing, software\n",
        "distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "See the License for the specific language governing permissions and\n",
        "limitations under the License."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "h5tjHuaYagOg"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "import numpy as np\n",
        "\n",
        "!git clone https://github.com/google-research/google-research.git google_research --depth=1\n",
        "from google_research.loss_functions_transfer import resnet_preprocessing\n",
        "from google_research.loss_functions_transfer import load_model\n",
        "\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QBPa3cM0KcO7"
      },
      "source": [
        "# Loading and running the SavedModels.\n",
        "\n",
        "\n",
        "See `LOSS_HYPERPARAMETERS` in [load_model.py](https://github.com/google-research/google-research/blob/master/loss_functions_transfer/load_model.py)\n",
        "for valid choices of `loss_name`. There are 8 seeds per loss function, so `seed` should be in `{0, ..., 7}`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EksBmfjOzYNW"
      },
      "outputs": [],
      "source": [
        "import tensorflow_datasets as tfds\n",
        "from tqdm import tqdm\n",
        "\n",
        "loss_name = 'softmax'\n",
        "seed = 0\n",
        "\n",
        "# Load SavedModel. The loaded object gets its own name instead of being\n",
        "# overwritten by the signature, so it stays referenced while the signature\n",
        "# function is in use.\n",
        "saved_model = tf.saved_model.load(\n",
        "    f'{load_model.BASE_PATH}/savedmodels/{loss_name}/seed{seed}')\n",
        "model = saved_model.signatures['serving_default']\n",
        "\n",
        "# Construct dataset.\n",
        "dataset = tfds.load('imagenet_v2', split='test')\n",
        "# For this Colab, we use ImageNet-V2 for eval because it can be downloaded w/o\n",
        "# manual intervention. To use ImageNet instead, uncomment below.\n",
        "# dataset = tfds.load('imagenet2012', split='validation')\n",
        "\n",
        "# Set up preprocessing.\n",
        "def preprocess_fn(example):\n",
        "  \"\"\"Applies eval-mode ResNet preprocessing to a single example.\n",
        "\n",
        "  Args:\n",
        "    example: Dict with 'image' and 'label' entries.\n",
        "\n",
        "  Returns:\n",
        "    Dict with the preprocessed 'image' and the unchanged 'label'.\n",
        "  \"\"\"\n",
        "  return {\n",
        "      'image': resnet_preprocessing.preprocess_image(\n",
        "          example['image'], 224, 224, is_training=False),\n",
        "      'label': example['label']\n",
        "  }\n",
        "\n",
        "# Preprocess examples in parallel and prefetch batches so that the input\n",
        "# pipeline overlaps with model execution. Element order stays deterministic\n",
        "# (the tf.data default), so the evaluation result is unchanged.\n",
        "dataset = (\n",
        "    dataset\n",
        "    .map(preprocess_fn, num_parallel_calls=tf.data.AUTOTUNE)\n",
        "    .batch(128, drop_remainder=False)\n",
        "    .prefetch(tf.data.AUTOTUNE))\n",
        "\n",
        "# Perform evaluation.\n",
        "correct = 0\n",
        "n = 0\n",
        "for batch in tqdm(dataset):\n",
        "  # The model returns the output of each block as well as the average pooling\n",
        "  # layer, but we use only the final outputs here.\n",
        "  outputs = model(image=batch['image'])['outputs']\n",
        "  # The model outputs 1001 classes, but ImageNet contains only 1000. The first\n",
        "  # class is an unused background class. We drop it when evaluating.\n",
        "  outputs = outputs[:, 1:]\n",
        "  predictions = tf.argmax(outputs, -1)\n",
        "  # Accumulate as a Python int so the final accuracy is plain Python arithmetic.\n",
        "  is_correct = tf.cast(predictions == batch['label'], tf.int64)\n",
        "  correct += int(tf.reduce_sum(is_correct))\n",
        "  # The last batch may be smaller than 128 (drop_remainder=False), so count\n",
        "  # the actual number of examples in each batch.\n",
        "  n += batch['label'].shape[0]\n",
        "print()\n",
        "print(f'Accuracy: {correct / n * 100:.1f}%')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4dv7jLPmKZiq"
      },
      "source": [
        "# Constructing, restoring, and running the checkpoint in graph mode."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mniGaPcoqJee"
      },
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "loss_name = 'logit_penalty'\n",
        "seed = 0\n",
        "\n",
        "# Dummy batch holding a single all-zero image. Real images should first be\n",
        "# passed through:\n",
        "# resnet_preprocessing.preprocess_image(image, 224, 224, is_training=False)\n",
        "# (The SavedModel cell above shows how to do this with tensorflow_datasets.)\n",
        "inputs_np = np.zeros((1, 224, 224, 3))\n",
        "# One row of width 1001 whose only nonzero entry (a 1) is in the last column.\n",
        "labels_np = np.pad([[1]], ((0, 0), (1000, 0)))\n",
        "\n",
        "# Build the model in a fresh graph, restore the weights, and run one step.\n",
        "with tf.Graph().as_default():\n",
        "  inputs = tf.compat.v1.placeholder(tf.float32, shape=(None, 224, 224, 3))\n",
        "  labels = tf.compat.v1.placeholder(tf.float32, shape=(None, 1001))\n",
        "  loss, endpoints = load_model.build_model_and_compute_loss(\n",
        "      loss_name=loss_name,\n",
        "      inputs=inputs,\n",
        "      labels=labels,\n",
        "      is_training=False)\n",
        "  with tf.compat.v1.Session() as sess:\n",
        "    load_model.restore_checkpoint(loss_name, seed, sess)\n",
        "    loss_np, endpoints_np = sess.run(\n",
        "        (loss, endpoints),\n",
        "        feed_dict={inputs: inputs_np, labels: labels_np})\n",
        "\n",
        "# Display the shape of every endpoint returned by the model.\n",
        "{name: array.shape for name, array in endpoints_np.items()}"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [
        {
          "file_id": "1wNQ3fC1njyoCaszagQCXVuDg8IPXJnLF",
          "timestamp": 1669410241552
        }
      ]
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
